# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import numpy as np

# Reference ("golden") values for the integration tests, one dict per
# environment variant. Each variant maps an integration test name to a dict of
# recorded quantities keyed by name (e.g. "losses", "best_metric",
# "output_sums"). ``test_integration_value`` below scans this list in order and
# accepts the first variant whose value for (test_name, key) is approximately
# equal to the observed data (``np.allclose`` with the caller's ``rtol``), so
# several variants can coexist for the same test.
# Variants do not all cover the same tests: the first four only contain
# "integration_segmentation_3d"; the last only contains
# "integration_classification_2d" and "integration_workflows".
EXPECTED_ANSWERS = [
    {  # test answers for PyTorch 2.0
        "integration_segmentation_3d": {
            "losses": [
                0.5430086106061935,
                0.47010003924369814,
                0.4453376233577728,
                0.451901963353157,
                0.4398456811904907,
                0.43450237810611725,
            ],
            "best_metric": 0.9329540133476257,
            "infer_metric": 0.9330471754074097,
            "output_sums": [
                0.14212507078546172,
                0.15199039602949577,
                0.15133471939291526,
                0.13967984811021827,
                0.18831614355832332,
                0.1694076821827231,
                0.14663931509271658,
                0.16788710637623733,
                0.1569452710008219,
                0.17907130698392254,
                0.16244092698688475,
                0.1679350345855819,
                0.14437674754879065,
                0.11355098478396568,
                0.161660275855964,
                0.20082478187698194,
                0.17575491677668853,
                0.0974593860605401,
                0.19366775441539907,
                0.20293016863409002,
                0.19610441127101647,
                0.20812173772459808,
                0.16184212006067655,
                0.13185211452732482,
                0.14824716961304257,
                0.14229818359602905,
                0.23141282114085215,
                0.1609268635938338,
                0.14825300029123678,
                0.10286266811772046,
                0.11873484714087054,
                0.1296615212510262,
                0.11386621034856693,
                0.15203351148564773,
                0.16300823766585265,
                0.1936726544485426,
                0.2227251185536394,
                0.18067789917505797,
                0.19005874127683337,
                0.07462121515702229,
            ],
        }
    },
    {  # test answers for cuda 12
        "integration_segmentation_3d": {
            "losses": [
                0.5362162500619888,
                0.4704935997724533,
                0.4335438072681427,
                0.4507470965385437,
                0.45187077224254607,
                0.4363303750753403,
            ],
            "best_metric": 0.9334161877632141,
            "infer_metric": 0.9335371851921082,
            "output_sums": [
                0.14210400101844414,
                0.1521489829835625,
                0.15127096315211278,
                0.13992817339153868,
                0.1884040828001848,
                0.16929503899789516,
                0.14662516818085808,
                0.16803982264111883,
                0.1570018930834878,
                0.17916684191571494,
                0.1626376090146162,
                0.1680113549677271,
                0.1446708736188978,
                0.1140289628362559,
                0.16191495673888556,
                0.20066696225510708,
                0.17581812459936835,
                0.09836918048666465,
                0.19355007524499268,
                0.20291004237066343,
                0.19606797329772976,
                0.2082113232291515,
                0.16189564397603906,
                0.13203990336741953,
                0.14849477534402156,
                0.14250633066863938,
                0.23139529505006795,
                0.16079877619802546,
                0.14821067071610583,
                0.10302449386782145,
                0.11876349315302756,
                0.13006925219380802,
                0.11431448379763984,
                0.15254606148569302,
                0.16317147221367873,
                0.19376668030880526,
                0.22260597124465822,
                0.18085088544070227,
                0.19010916899493174,
                0.07748195410499427,
            ],
        }
    },
    {  # test answers for 23.02 (presumably a container/release tag — confirm)
        "integration_segmentation_3d": {
            "losses": [
                0.5401686698198318,
                0.4789864182472229,
                0.4417317628860474,
                0.44183324575424193,
                0.4418945342302322,
                0.44213996827602386,
            ],
            "best_metric": 0.9316274523735046,
            "infer_metric": 0.9321609735488892,
            # NOTE(review): these output_sums are identical to the PyTorch 2.0
            # entry's output_sums (only losses/metrics differ) — confirm intentional.
            "output_sums": [
                0.14212507078546172,
                0.15199039602949577,
                0.15133471939291526,
                0.13967984811021827,
                0.18831614355832332,
                0.1694076821827231,
                0.14663931509271658,
                0.16788710637623733,
                0.1569452710008219,
                0.17907130698392254,
                0.16244092698688475,
                0.1679350345855819,
                0.14437674754879065,
                0.11355098478396568,
                0.161660275855964,
                0.20082478187698194,
                0.17575491677668853,
                0.0974593860605401,
                0.19366775441539907,
                0.20293016863409002,
                0.19610441127101647,
                0.20812173772459808,
                0.16184212006067655,
                0.13185211452732482,
                0.14824716961304257,
                0.14229818359602905,
                0.23141282114085215,
                0.1609268635938338,
                0.14825300029123678,
                0.10286266811772046,
                0.11873484714087054,
                0.1296615212510262,
                0.11386621034856693,
                0.15203351148564773,
                0.16300823766585265,
                0.1936726544485426,
                0.2227251185536394,
                0.18067789917505797,
                0.19005874127683337,
                0.07462121515702229,
            ],
        }
    },
    {  # test answers for 24.03 (presumably a container/release tag — confirm)
        "integration_segmentation_3d": {
            "losses": [
                0.5442982316017151,
                0.4741817444562912,
                0.4535954713821411,
                0.44163046181201937,
                0.4307525992393494,
                0.428487154841423,
            ],
            "best_metric": 0.9314384460449219,
            "infer_metric": 0.9315622448921204,
            "output_sums": [
                0.14268704426414708,
                0.1528672845845743,
                0.1521782248125706,
                0.14028769128068194,
                0.1889830671664784,
                0.16999075690664475,
                0.14736282992708227,
                0.16877952654821815,
                0.15779597155181269,
                0.17987829927082263,
                0.16320253928314676,
                0.16854299322173155,
                0.14497470986956967,
                0.11437140546369519,
                0.1624117412960871,
                0.20156009294443875,
                0.1764654154256958,
                0.0982348259217418,
                0.1942436068604293,
                0.20359421536407518,
                0.19661953116976483,
                0.2088326101468625,
                0.16273043545239807,
                0.1326107887439663,
                0.1489245275752285,
                0.143107476635514,
                0.23189027677929547,
                0.1613818424566088,
                0.14889532196775188,
                0.10332622984492143,
                0.11940054688302351,
                0.13040496302762658,
                0.11472123087193181,
                0.15307044007394474,
                0.16371989575844717,
                0.1942898223272055,
                0.2230120930471398,
                0.1814679187634795,
                0.19069496508164732,
                0.07537197031940022,
            ],
        }
    },
    {  # test answers for 24.10 (presumably a container/release tag — confirm)
        # Unlike the entries above, this one has no "integration_segmentation_3d" answers.
        "integration_classification_2d": {
            # A single float here, unlike the list-valued "losses" entries above.
            "losses": 0.7806512035761669,
            "best_metric": 0.9977695200407783,
            "infer_prop": [805, 727, 955, 1033, 321, 993],
        },
        "integration_workflows": {
            "best_metric": 0.9207136034965515,
            "best_metric_2": 0.9216295480728149,
            "infer_metric": 0.920440673828125,
            "infer_metric_2": 0.9203161001205444,
            "output_sums": [
                0.1423349380493164,
                0.15172767639160156,
                0.1382155418395996,
                0.13398218154907227,
                0.18552064895629883,
                0.16435527801513672,
                0.14128494262695312,
                0.16725540161132812,
                0.15690851211547852,
                0.17731285095214844,
                0.16189050674438477,
                0.16543960571289062,
                0.14431238174438477,
                0.11064529418945312,
                0.16129302978515625,
                0.1970067024230957,
                0.17503118515014648,
                0.053476810455322266,
                0.1914362907409668,
                0.2001795768737793,
                0.19636154174804688,
                0.2040243148803711,
                0.1606454849243164,
                0.13213014602661133,
                0.15132904052734375,
                0.1370987892150879,
                0.22805070877075195,
                0.16170072555541992,
                0.1477980613708496,
                0.10428047180175781,
                0.1195521354675293,
                0.13089942932128906,
                0.11238527297973633,
                0.15204906463623047,
                0.1603565216064453,
                0.19054937362670898,
                0.21789216995239258,
                0.17824840545654297,
                0.18654584884643555,
                0.03622245788574219,
            ],
            # NOTE(review): identical to "output_sums" above — confirm intentional.
            "output_sums_2": [
                0.1423349380493164,
                0.15172767639160156,
                0.1382155418395996,
                0.13398218154907227,
                0.18552064895629883,
                0.16435527801513672,
                0.14128494262695312,
                0.16725540161132812,
                0.15690851211547852,
                0.17731285095214844,
                0.16189050674438477,
                0.16543960571289062,
                0.14431238174438477,
                0.11064529418945312,
                0.16129302978515625,
                0.1970067024230957,
                0.17503118515014648,
                0.053476810455322266,
                0.1914362907409668,
                0.2001795768737793,
                0.19636154174804688,
                0.2040243148803711,
                0.1606454849243164,
                0.13213014602661133,
                0.15132904052734375,
                0.1370987892150879,
                0.22805070877075195,
                0.16170072555541992,
                0.1477980613708496,
                0.10428047180175781,
                0.1195521354675293,
                0.13089942932128906,
                0.11238527297973633,
                0.15204906463623047,
                0.1603565216064453,
                0.19054937362670898,
                0.21789216995239258,
                0.17824840545654297,
                0.18654584884643555,
                0.03622245788574219,
            ],
        },
    },
]


def test_integration_value(test_name, key, data, rtol=1e-2, answers=None):
    """Check observed ``data`` against the recorded reference values of an integration test.

    Every entry of ``answers`` (``EXPECTED_ANSWERS`` by default) is tried in order.
    The first entry that contains ``entry[test_name][key]`` and whose value is
    ``np.allclose`` to ``data`` (relative tolerance ``rtol``) is accepted.

    Args:
        test_name: integration test name, e.g. ``"integration_segmentation_3d"``.
        key: recorded quantity name, e.g. ``"losses"`` or ``"output_sums"``.
        data: observed value(s); anything ``np.allclose`` accepts (scalar or sequence).
        rtol: relative tolerance passed to ``np.allclose``.
        answers: optional list of reference dicts laid out like ``EXPECTED_ANSWERS``;
            ``None`` (the default) uses ``EXPECTED_ANSWERS``.

    Returns:
        True when a matching reference entry is found.

    Raises:
        ValueError: if no reference entry matches ``data``.
    """
    candidates = EXPECTED_ANSWERS if answers is None else answers
    for idx, expected in enumerate(candidates):
        if test_name not in expected or key not in expected[test_name]:
            continue
        value = expected[test_name][key]
        try:
            matched = np.allclose(data, value, rtol=rtol)
        except ValueError:
            # Shapes cannot be broadcast together (e.g. a reference list of a
            # different length): this entry cannot match, so keep searching the
            # remaining entries instead of aborting the whole lookup.
            continue
        if matched:
            print(f"matched {idx} result of {test_name}, {key}, {rtol}.")
            return True
    raise ValueError(f"no matched results for {test_name}, {key}. {data}.")


# This helper is imported into test modules under a ``test_`` name; tell pytest not
# to collect it as a test (its arguments are not fixtures pytest could supply).
test_integration_value.__test__ = False
